
################################################################################
#
# Package: machinelearningtools
# Purpose: Provide convenience functions for machine learning with caret
#
################################################################################


################################################################################
# set model input (target label + feature labels) > formula
################################################################################
set_formula <- function(target_label, features) {

  features %>%
    paste(collapse = " + ") %>%
    paste(target_label, "~", .) %>%
    as.formula(env = .GlobalEnv)

}
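
## Example (sketch; the labels below are just example strings):
##   set_formula("mpg", c("cyl", "hp", "wt"))
##   # returns the formula: mpg ~ cyl + hp + wt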


################################################################################
# get_model_metrics:
#   calculate training set (resamples) and testing set performance:
#   mean & sd of each metric for all model objects in models_list
#
# set colors either by palette name or by explicit color codes
##
##  palette:
##    models.list %>% get_model_metrics(palette = "Dark2")
##
##  color codes:
##    models.list %>% get_model_metrics(
##      colors = c("#4DAF4A", "#E41A1C", "#FF7F00", "#377EB8"))
##
##  colors: "#4DAF4A" green "#377EB8" blue "#E41A1C" red "#FF7F00" orange
##
################################################################################
get_model_metrics <- function(
  models_list, target_label = NULL, testing_set = NULL,
  median_sort = FALSE, reverse = FALSE,
  palette = "Set1", colors = NULL,
  boxplot_fill = "grey95", boxplot_color = "grey25") {

  require(RColorBrewer)

  # set target.label from the argument if given, else from models_list
  target.label <- if (!is.null(target_label)) target_label else models_list$target.label

  # set testing set to argument > from models_list > NULL if empty
  if (!is.null(testing_set)) {
    testing.set <- testing_set
  } else if (is.null(models_list$testing.set)) {
    testing.set <- NULL
  } else if (nrow(models_list$testing.set) != 0) {
    testing.set <- models_list$testing.set
  } else { # e.g. if testing.set exists but has 0 rows
    testing.set <- NULL
  }

  # remove target.label + testing.set from models.list
  if (!is.null(models_list$target.label)) {
    models_list %<>% purrr::list_modify("target.label" = NULL)
  }
  if (!is.null(models_list$testing.set)) {
    models_list %<>% purrr::list_modify("testing.set" = NULL)
  }

  target <- models_list[[1]]$trainingData$.outcome

  if (is.factor(target)) {

    metric1 = "Accuracy"
    metric2 = "Kappa"
    metric3 = NULL
    metric1.descending = FALSE
    metric2.descending = FALSE
    metric3.descending = FALSE

  } else if (is.numeric(target)) {

    metric1 = "RMSE"
    metric2 = "Rsquared"
    metric3 = "R"
    metric1.descending = TRUE
    metric2.descending = FALSE
    metric3.descending = FALSE
  }

  # reverse flips the default sort direction of each metric
  if (reverse) {
    metric1.descending <- !metric1.descending
    metric2.descending <- !metric2.descending
    metric3.descending <- !metric3.descending
  }

  ### get metrics from original resamples' folds
  resamples.values <- models_list %>% caret::resamples() %>% .$values %>%
    # select_if(is.numeric) %>%
    # retrieve RMSE, Rsquared but not MAE
    ## tricky: select without dplyr:: prefix does NOT work
    # dplyr::select(ends_with("RMSE"), ends_with("Rsquared"))
    dplyr::select(ends_with(metric1), ends_with(metric2)) %>%
    # calculate R from R-squared variables
    dplyr::mutate(
      across(
        .cols = ends_with(metric2), # R-squared
        .fns = sqrt,
        .names = "{.col}.R"
      )
    ) %>%
    set_names(gsub(paste0(metric2, ".R"), "R", names(.)))

  ### calculate mean and sd for each metric
  metric1.training <- get_metric_from_resamples(
    resamples.values, metric1, median_sort)
  metric2.training <- get_metric_from_resamples(
    resamples.values, metric2, median_sort)
  # metric3 (R) is only defined for regression
  metric3.training <- if (!is.null(metric3)) {
    get_metric_from_resamples(resamples.values, metric3, median_sort)
  } else {
    NULL
  }

  ### visualize the resampling distribution from cross-validation
  metric1.resamples.boxplots <- visualize_resamples_boxplots(
    resamples.values, metric1, palette, colors = colors,
    descending = metric1.descending,
    boxplot_fill = boxplot_fill, boxplot_color = boxplot_color)
  metric2.resamples.boxplots <- visualize_resamples_boxplots(
    resamples.values, metric2, palette, colors = colors,
    descending = metric2.descending,
    boxplot_fill = boxplot_fill, boxplot_color = boxplot_color)

  metric3.resamples.boxplots <- if (!is.null(metric3)) {
    visualize_resamples_boxplots(
      resamples.values, metric3, palette, colors = colors,
      descending = metric3.descending,
      boxplot_fill = boxplot_fill, boxplot_color = boxplot_color)
  } else {
    NULL
  }

  if (!is.null(testing.set)) {
    metrics.testing <- get_testingset_performance(
      models_list, target.label, testing.set
    )
  } else {
    metrics.testing <- NULL
  }

  if (is.factor(target)) { # classification

    benchmark.all <- merge(metric1.training, metric2.training, by = "model") %>%
      {
        if (!is.null(metrics.testing)) {
          # tricky: within a conditional {} block, must refer to the LHS (.)
          merge(., metrics.testing, by = "model") %>%
            arrange(desc(Acc.testing))
        } else {
          .
        }
      } %>%
      as_tibble(.)

  } else if (is.numeric(target)) { # regression

    benchmark.all <- merge(metric1.training, metric2.training, by = "model") %>%
      merge(metric3.training, by = "model") %>%
      {
        if (!is.null(metrics.testing)) {
          # tricky: within a conditional {} block, must refer to the LHS (.)
          merge(., metrics.testing, by = "model") %>%
            dplyr::mutate(RMSE.delta = RMSE.testing - RMSE.mean) %>%
            arrange(RMSE.testing)
        } else {
          .
        }
      } %>%
      as_tibble(.)
  }

  return(list(metric1 = metric1,
              metric2 = metric2,
              resamples.values = resamples.values,
              metric1.training = metric1.training,
              metric2.training = metric2.training,
              metric3.training = metric3.training,
              metric1.resamples.boxplots = metric1.resamples.boxplots,
              metric2.resamples.boxplots = metric2.resamples.boxplots,
              metric3.resamples.boxplots = metric3.resamples.boxplots,
              metrics.testing = metrics.testing,
              benchmark.all = benchmark.all
  ))
}
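
## Example (sketch, assuming models.list was created by benchmark_algorithms
## and therefore carries target.label and testing.set):
##   metrics <- models.list %>% get_model_metrics()
##   metrics$benchmark.all               # training + testing set benchmark table
##   metrics$metric1.resamples.boxplots  # boxplots of the first metric's resamples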


################################################################################
# get_metric_from_resamples
# Helper function for get_model_metrics
################################################################################
get_metric_from_resamples <- function(
  resamples_values, metric, median_sort = FALSE) {

  require(dplyr)
  require(tidyr)

  suffix <- paste0("~", metric)
  # tricky: for arrange, convert string column name to symbol, not quosure
  # https://stackoverflow.com/a/26497839/7769076
  metric.mean <- rlang::sym(paste0(metric, ".mean"))
  metric.sd <- paste0(metric, ".sd")
  metric.median <- rlang::sym(paste0(metric, ".median"))

  # tricky: use if () ... else, as ifelse() cannot return symbols
  sort.metric <- if (median_sort) metric.median else metric.mean

  resamples_values %>%
    dplyr::select(ends_with(suffix)) %>%
    rename_with(~gsub(suffix, "", .)) %>%
    summarize(across(everything(),
                     list(median = median, mean = mean, sd = sd))) %>%
    # genius tip (.value!): https://stackoverflow.com/a/58880309/7769076
    pivot_longer(
      cols = everything(),
      names_to = c("model", ".value"),
      names_pattern =  "(.+)_(.+$)"
    ) %>%
    set_names(c(
      "model",
      as.character(metric.median),
      as.character(metric.mean),
      metric.sd
    )) %>%
    { # put the mean + sd columns first unless sorting by median
      if (!median_sort) {
        select(., model, ends_with("mean"), ends_with("sd"), ends_with("median"))
      } else { . }
    } %>%
    {
      if (metric == "RMSE") {
        # tricky: unquote symbol, not quosure
        # tricky: must use . inside inline dplyr code {}
        arrange(., !!sort.metric)
      } else { # for Accuracy, Kappa AND Rsquared: sort by descending order

        arrange(., desc(!!sort.metric))
      }
    }
}
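
## Example (sketch, assuming resamples.values contains columns such as
## "rf~RMSE" and "lm~RMSE", as returned by caret::resamples()$values):
##   get_metric_from_resamples(resamples.values, "RMSE")
##   # returns a tibble: model | RMSE.mean | RMSE.sd | RMSE.median,
##   # sorted by ascending RMSE.mean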


################################################################################
# get_metric_resamples
# Helper function for tidyposterior
################################################################################
get_metric_resamples <- function(resamples_data, metric) {

  require(dplyr)

  resamples_data %>%
    .$values %>%
    as_tibble() %>%
    select(Resample, contains(metric)) %>%
    # tricky: tilde (~) NOT dash (-)
    setNames(gsub(paste0("~", metric), "", names(.))) %>%
    dplyr::rename(id = Resample)
}
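
## Example (sketch, assuming models.list contains only caret::train objects):
##   caret::resamples(models.list) %>% get_metric_resamples("RMSE")
##   # returns a tibble: id (resample) | one RMSE column per model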

################################################################################
# visualize_resamples_boxplots()
# Helper function for get_model_metrics
################################################################################
visualize_resamples_boxplots <- function(
  resamples_values,
  METRIC,
  palette = "Set1",
  descending = FALSE,
  color_count = NULL,
  dot_size = NULL,
  boxplot_fill = "grey95",
  boxplot_color = "grey25",
  colors = NULL,
  exclude_light_hues = NULL
) {
  require(dplyr)
  require(tidyr)
  require(ggplot2)
  require(RColorBrewer)

  # dot size of the resamples distribution is inversely proportional to their number
  if (is.null(dot_size)) dot_size <- 1 / logb(nrow(resamples_values), 5)

  # extract the resamples values for selected METRIC (e.g. "Accuracy" or "RMSE")
  resamples.by.metric <- resamples_values %>%
    dplyr::select(ends_with(METRIC)) %>%
    purrr::set_names(~ gsub(paste0("~", METRIC), "", .)) %>%
    drop_na() %>%
    pivot_longer(
      cols = everything(),
      names_to = "model",
      values_to = METRIC,
      names_transform = list(model = as.factor)
    )

  # extract 8 HEX color codes from the base palette (requires a palette with 8+ colors)
  ## Source: http://novyden.blogspot.com/2013/09/how-to-expand-color-palette-with-ggplot.html
  color.codes <- brewer.pal(8, palette)

  # remove the first color codes of palette as they have very light hues
  if (!is.null(exclude_light_hues)) {
    color.codes %<>% .[-c(1:exclude_light_hues)]
  }

  # default the number of colors to the number of model~metric columns
  if (is.null(color_count)) color_count <- ncol(resamples_values)

  # generate the color palette by extrapolation from color.codes to color_count
  color.palette.generated <- colorRampPalette(color.codes)(color_count)

  resamples.boxplots <- resamples.by.metric %>%
    ggplot(aes(
      # order the models by their median metric value
      x = if (descending) {
        reorder(model, desc(!!sym(METRIC)), median)
      } else {
        reorder(model, !!sym(METRIC), median)
      },
      y = !!sym(METRIC),
      color = model
    )) +
    geom_boxplot(
      width = 0.7,
      fill = boxplot_fill,
      color = boxplot_color,
      alpha = 0.3
    ) +
    geom_jitter(size = dot_size) +
    coord_flip() +
    scale_color_manual(
      values = if (!is.null(colors)) colors else color.palette.generated
    ) +
    labs(x = "model", y = METRIC) +
    theme_minimal() +
    theme(
      legend.position = "none",
      axis.title = element_text(size = 14),
      axis.text = element_text(size = 14)
    )

  return(resamples.boxplots)
}
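
## Example (sketch, using the resamples.values object built in get_model_metrics):
##   visualize_resamples_boxplots(resamples.values, "RMSE",
##                                palette = "Dark2", descending = TRUE)
##   # returns a ggplot object with one boxplot + jittered resamples per model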


#######################################################################
# logical_string: return true_string if logical_flag is TRUE, else NULL (e.g. for filenames)
#######################################################################
logical_string <- function(logical_flag, true_string) {

  if (logical_flag) true_string else NULL

}
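
## Example (sketch; formula_input is a hypothetical flag):
##   paste0("models.list", logical_string(formula_input, ".formula"), ".rds")
##   # yields "models.list.formula.rds" if TRUE, "models.list.rds" if FALSE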

#######################################################################
# benchmark algorithms with caret::train
#######################################################################
benchmark_algorithms <- function(

  target_label,
  features_labels,
  training_set,
  testing_set,
  formula_input = FALSE,
  preprocess_configuration = c("center", "scale", "zv"),
  training_configuration,
  impute_method = NULL,
  algorithm_list,
  glm_family = NULL,
  seed = 17,
  cv_repeats,
  try_first = NULL,
  models_list_name = NULL,
  cluster_log = "",
  beep = TRUE,
  push = TRUE) {

  ########################################
  ## 2.3 Print the target & features
  ########################################
  target_label %>% print
  features_labels %>% print

  ########################################
  # 3.2: Select the target & features
  ########################################
  target <- training_set[[target_label]]
  # avoid tibble e.g. for svmRadial: "setting rownames on tibble is deprecated"
  features <- training_set %>% select(all_of(features_labels)) %>% as.data.frame

  if (!is.null(try_first) && is.numeric(try_first)) {

    target %<>% head(try_first)
    features %<>% head(try_first)
    training_set %<>% head(try_first)

  }

  ########################################
  # 3.3: Train the models
  ########################################
  models.list <- list()

  if (formula_input) {

    print("******** FORMULA interface")

    # define formula
    formula1 <- set_formula(target_label, features_labels)

    system.time(
      models.list <- algorithm_list %>%

        map(function(algorithm_label) {

          print(paste("***", algorithm_label))

          ############ START new cluster for model training
          cluster.new <- clusterOn(outfile_name = cluster_log)
          # stop cluster if training throws error (https://stackoverflow.com/a/41679580/7769076)
          on.exit(if (exists("cluster.new")) { clusterOff(cluster.new) } )

          if (algorithm_label == "rf") {

            model <- train(
              form = formula1,
              method = algorithm_label,
              data = training_set,
              preProcess = preprocess_configuration,
              trControl = training_configuration,
              importance = TRUE
            )

          } else if (algorithm_label == "ranger") {

            model <- train(
              form = formula1,
              method = algorithm_label,
              data = training_set,
              preProcess = preprocess_configuration,
              trControl = training_configuration,
              importance = "impurity"
            )

          } else if (algorithm_label == "glm" | algorithm_label == "glmnet") {

            model <- train(
              form = formula1,
              method = algorithm_label,
              family = glm_family,
              data = training_set,
              preProcess = preprocess_configuration,
              trControl = training_configuration
            )
          } else {

            model <- train(
              form = formula1,
              method = algorithm_label,
              data = training_set,
              preProcess = preprocess_configuration,
              trControl = training_configuration
            )
          }
          ############ END model training & STOP cluster
          clusterOff(cluster.new)
          stopImplicitCluster()

          return(model)
        }) %>%
        setNames(algorithm_list)
    ) %T>% {
      if (beep) beepr::beep()
      if (push) push_message(
        time_in_seconds = .["elapsed"],
        algorithm_list = algorithm_list,
        models_list_name = models_list_name
      )
    }
  } else { # x,y interface (handles categorical features via optional one-hot encoding)

    print("******** X Y INTERFACE")

    # transform categorical features by one-hot-encoding for models except rf, ranger, gbm
    # e.g. glmnet expects features as model.matrix (source: https://stackoverflow.com/a/48230658/7769076)
    if (contains_factors(training_set)) {

      formula1 <- set_formula(target_label, features_labels)
      features.onehot <- model.matrix(formula1, data = training_set) %>%
        as.data.frame() %>%
        select(-`(Intercept)`)
      # training.set.onehot <- cbind(target, features.onehot)
    }
    # back up the original features before the loop to avoid overwriting them
    features.original <- features
    # training.set.original <- training_set

    system.time(
      models.list <- algorithm_list %>%

        map(function(algorithm_label) {

          print(paste("***", algorithm_label))

          # transform factors by one-hot-encoding for all models except rf, ranger, gbm
          if (contains_factors(training_set) &&
              !handles_factors(algorithm_label) &&
              !algorithm_label %in% c("svmRadial", "svmLinear")
          ) {

            features <- features.onehot
            # training.set <- training.set.onehot
            print(paste("*** performed one-hot-encoding for model", algorithm_label))

          } else { # no onehot-encoding

            features <- features.original
            # training.set <- training.set.original

          }

          ############ START new cluster for model training
          cluster.new <- clusterOn(outfile_name = cluster_log)
          # stop cluster if training throws error (https://stackoverflow.com/a/41679580/7769076)
          on.exit(if (exists("cluster.new")) { clusterOff(cluster.new) } )

          if (algorithm_label == "rf") {

            model <- train(
              x = features,
              y = target,
              method = algorithm_label,
              preProcess = preprocess_configuration,
              trControl = training_configuration,
              importance = TRUE
            )
          } else if (algorithm_label == "ranger") {

            model <- train(
              x = features,
              y = target,
              method = algorithm_label,
              preProcess = preprocess_configuration,
              trControl = training_configuration,
              importance = "impurity"
            )

          } else if (is.factor(target) &&
                     algorithm_label %in% c("glm", "glmnet")
          ) {

            model <- train(
              x = features,
              y = target,
              method = "glm",
              family = glm_family,
              preProcess = preprocess_configuration,
              trControl = training_configuration
            )

          } else if (algorithm_label == "xgbTree" | algorithm_label == "xgbLinear") {

            model <- train(
              x = features,
              y = target,
              method = algorithm_label,
              nthread = 1,
              preProcess = preprocess_configuration,
              trControl = training_configuration
            )

          } else if (algorithm_label == "svmRadial" | algorithm_label == "svmLinear") {

            # predict() requires kernlab::ksvm object created by formula:
            # https://stackoverflow.com/q/52743663/7769076
            formula.svm <- set_formula(target_label, features_labels)

            model <- train(
              form = formula.svm,
              method = algorithm_label,
              data = training_set,
              preProcess = preprocess_configuration,
              trControl = training_configuration
            )

          } else {

            model <- train(
              x = features,
              y = target,
              method = algorithm_label,
              preProcess = preprocess_configuration,
              trControl = training_configuration
            )
          }
          ############ END model training & STOP cluster
          clusterOff(cluster.new)
          stopImplicitCluster()

          return(model)
        }) %>%
        setNames(algorithm_list)
    ) %T>% {
      if (beep) beepr::beep()
      if (push) push_message(
        time_in_seconds = .["elapsed"],
        algorithm_list = algorithm_list,
        models_list_name = models_list_name
      )
    }
  }

  ########################################
  # Postprocess the models
  ########################################
  # add target.label & testing.set to models.list
  models.list$target.label <- target_label
  models.list$testing.set <- testing_set

  # save the models.list
  if (!is.null(models_list_name)) {

    models.list %>% saveRDS(models_list_name)

    print(paste("model training results saved in", models_list_name))
  }
  print(models.list)
  return(models.list)
}
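
## Example (sketch, assuming training.set/testing.set are data frames containing
## the target "mpg" and the listed features, and that caret plus the package's
## cluster helpers (clusterOn/clusterOff) are available):
##   models.list <- benchmark_algorithms(
##     target_label = "mpg",
##     features_labels = c("cyl", "hp", "wt"),
##     training_set = training.set,
##     testing_set = testing.set,
##     training_configuration = trainControl(method = "repeatedcv",
##                                           number = 10, repeats = 3),
##     algorithm_list = c("lm", "glmnet", "ranger"),
##     cv_repeats = 3,
##     beep = FALSE, push = FALSE
##   )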

################################################################################
# Dataset contains Factors
# check if dataset contains categorical features
################################################################################
contains_factors <- function(data) {
  data %>%
    select_if(is.factor) %>%
    names %>%
    {length(.) > 0}
}
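
## Example: contains_factors(iris)    # TRUE  (Species is a factor)
##          contains_factors(mtcars)  # FALSE (all columns are numeric)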

################################################################################
# Algorithm handles Factors
# Check if algorithm handles categorical features without one-hot-encoding
################################################################################
handles_factors <- function(algorithm_label) {

  # models that can handle factors instead of one-hot-encoding
  algorithms.handling.factors <- c(
    "rf", "ranger", "gbm", "nnet"
  )

  # check whether the input algorithm handles factors
  algorithm_label %in% algorithms.handling.factors
}
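
## Example: handles_factors("ranger")  # TRUE
##          handles_factors("glmnet")  # FALSE -> needs one-hot encoding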

################################################################################
# Get feature set
# From vector of feature labels, generate feature set
################################################################################
get_featureset <- function(data,
                           target_label = NULL,
                           featureset_labels = NULL,
                           select_starts = NULL) {

  data %>%
    dplyr::select(!!rlang::sym(target_label)) %>%

    {
      if (!is.null(featureset_labels)) {
        cbind(.,
              data %>%
                dplyr::select(!!!rlang::syms(featureset_labels))
        )
      } else { . }
    } %>%
    {
      if (!is.null(select_starts)) {

        cbind(.,
              map_dfc(select_starts, function(start_keyword) {
                data %>%
                  select(starts_with(start_keyword))
              })
        )

      } else { . }
    } %>%
    as_tibble()
}
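
## Example (sketch; the column names are hypothetical):
##   get_featureset(data,
##                  target_label = "income",
##                  featureset_labels = c("age", "education"),
##                  select_starts = "score_")
##   # returns a tibble: income | age | education | all columns starting with "score_"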

################################################################################
# Get Testing Set Performance
# calculate testing set metrics (Accuracy/Kappa or RMSE/R2) for all models in models_list
################################################################################
get_testingset_performance <- function(
  models_list, target_label = NULL, testing_set = NULL) {

  # retrieve target.label + testing.set from models_list (and remove them), or from arguments
  if (!is.null(models_list$target.label) && !is.null(models_list$testing.set)) {

    target.label <- models_list$target.label
    testing.set <- models_list$testing.set
    models_list %<>% purrr::list_modify("target.label" = NULL, "testing.set" = NULL)

  } else if (!is.null(target_label) && !is.null(testing_set)) {

    target.label <- target_label
    testing.set <- testing_set

  } else {
    stop("provide target_label and testing_set, either as arguments or inside models_list")
  }

  features.labels <- testing.set %>% select(-all_of(target.label)) %>% names

  observed <- testing.set[[target.label]]

  # do onehot encoding for algorithms that cannot handle factors
  if (contains_factors(testing.set)) {

    formula1 <- set_formula(target.label, features.labels)
    testing.set.onehot <- model.matrix(formula1, data = testing.set) %>%
      as_tibble() %>%
      select(-`(Intercept)`)
  }

  if (is.factor(observed)) {

    models_list %>%
      map(
        function(model_object) {

          # set flag for onehot encoding
          onehot <- FALSE
          # do onehot encoding for algorithms that cannot handle factors
          if (contains_factors(testing.set) &&
              !handles_factors(model_object$method) &&
              !model_object$method %in% c("svmRadial", "svmLinear")) {

            onehot <- TRUE
          }

          model_object %>%
            # estimate target in the testing set
            predict(newdata = if (onehot) testing.set.onehot else testing.set) %>%
            confusionMatrix(., observed) %>%
            .$overall %>%
            # tricky: convert first to dataframe > can select column names
            map_df(1) %>%
            select(Accuracy, Kappa)
        }
      ) %>%
      bind_rows(.id = "model") %>%
      setNames(c("model", "Acc.testing", "Kappa.testing"))

  } else if (is.numeric(observed)) {

    models_list %>%
      map_dfr(
        function(model_object) {

          # set flag for onehot encoding
          onehot <- FALSE
          # do onehot encoding for algorithms that cannot handle factors
          if (contains_factors(testing.set) &&
              !handles_factors(model_object$method) &&
              !model_object$method %in% c("svmRadial", "svmLinear")) {

            onehot <- TRUE
          }

          mean.training.set <- models_list[[1]]$trainingData$.outcome %>% mean

          predicted <- model_object %>%
            # estimate target in the testing set
            predict(newdata = if (onehot) testing.set.onehot else testing.set)

          # return a named vector so map_dfr() binds one row per model
          c(
            # RMSE, equivalent to postResample(predicted, observed)["RMSE"]
            RMSE.testing = sqrt(mean((observed - predicted)^2)),
            # R2 relative to the training set mean
            # https://stackoverflow.com/a/36727900/7769076
            R2.testing = sum((predicted - mean.training.set)^2) /
              sum((observed - mean.training.set)^2),
            # R2 = regression SS / TSS, centered on the observed mean
            ## ?for centering, the same reference (observed) seems to be better?
            R2.testing2 = sum((predicted - mean(observed))^2) /
              sum((observed - mean(observed))^2),
            # squared correlation, equivalent to postResample(predicted, observed)["Rsquared"]
            R2.postResample = cor(predicted, observed)^2
          )
        },
        .id = "model"
      ) %>%
      arrange(RMSE.testing) %>%
      as.data.frame
  }
}
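
## Example (sketch, assuming models.list carries target.label and testing.set):
##   models.list %>% get_testingset_performance()
##   # classification: model | Acc.testing | Kappa.testing
##   # regression:     model | RMSE.testing | R2.testing | R2.testing2 | R2.postResample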

################################################################################
# Visualize variable importance
# input caret::train object
################################################################################
visualize_importance <- function (
  model_object, # caret::train object
  relative = FALSE, # calculate relative importance scores (not normalized)
  axis_label = NULL, # label for vertical axis
  axis_tick_labels = NULL, # labels for items/facets/factors
  text_labels = FALSE, # labels showing numeric scores next to bar
  axis_limit = NULL, # max. axis score displayed
  width = 4, height = 3, dpi = 300, # specs for saved plot
  fill_color = "#114151",
  font_size = 10,
  save_label = "" # filename for saved plot
) {

  require(caret)
  require(gbm)
  require(dplyr)
  require(tibble)
  require(ggplot2)

  # calculate feature importance
  importance_object <- model_object %>% caret::varImp()

  unit.label <- ifelse(relative, "%RI", "importance")
  unit.variable <- rlang::sym(unit.label)


  if (inherits(importance_object, "varImp.train")) {
    importance_object %<>% .$importance
  }
  if (!hasName(importance_object, "rowname")) {
    importance_object %<>% rownames_to_column()
  }

  importance.table <- importance_object %>%
    dplyr::rename(variable = rowname, importance = Overall) %>%
    arrange(desc(importance)) %>%
    {
      if (relative) {
        dplyr::mutate(., `%RI` = importance/sum(importance)*100) %>%
          select(variable, `%RI`)
      } else {
        .
      }
    }

  importance.plot <- importance.table %>%
    set_names(c("variable", unit.label)) %>%
    ggplot(data = .,
           aes(x = reorder(variable, !!unit.variable), y = !!unit.variable)) +
    theme_minimal() +
    geom_bar(stat = "identity", fill = fill_color) +
    {
      if (text_labels) {
        geom_text(aes(label = round(!!unit.variable, digits = 2)),
                  position = position_dodge(width = 5),
                  hjust = -0.1,
                  check_overlap = TRUE,
                  # tricky: font size must be scaled down by ggplot2:::.pt
                  # https://stackoverflow.com/a/17312440/7769076
                  size = font_size / (ggplot2:::.pt * 1.1)
        )
      }
    } +
    coord_flip() +
    theme(axis.title = element_text(size = font_size),
          axis.text = element_text(size = font_size)) +
    {
      if (!is.null(axis_limit)) {
        scale_y_continuous(expand = c(0, 0),
                           limits = c(0, axis_limit))
      }
    } +
    {
      if (!is.null(axis_tick_labels)) {
        scale_x_discrete(labels = axis_tick_labels)
      }
    } +
    labs(
      x = axis_label,
      y = unit.label
    )

  if (save_label != "") {
    ggsave(
      filename = save_label,
      plot = importance.plot,
      dpi = dpi,
      width = width,
      height = height
    )
  }

  return(
    list(
      importance.table = importance.table,
      importance.plot = importance.plot
    ))
}
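
## Example (sketch, assuming model1 is a single caret::train object):
##   result <- visualize_importance(model1, relative = TRUE, text_labels = TRUE)
##   result$importance.plot   # bar chart of (relative) variable importance
##   result$importance.table  # the underlying importance scores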